The focus of this notebook is the development of a high-performance method that retains only a subset of the tips within a tree. It can be used to pull out a subtree defined by the tips it spans, rather than one rooted at a particular internal node. The method operates on the array-based postorder representation of a tree produced by TreeNode.to_array.
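
For reference, here is that representation written out by hand for the small tree used in the tests below, ((a:1,b:2)c:3,(d:4,e:5)f:6)root;. The values are assumed to match what TreeNode.to_array returns for this tree (they are consistent with the expectations in ShearTests), and only the four keys used in this notebook are shown. Node IDs are assigned so that the children of every internal node are contiguous and have smaller IDs than their parent, with the root last. Each child_index row [node, left, right] therefore addresses a node's children as a single slice. The helper children_of is hypothetical and only illustrates this layout.

```python
import numpy as np

# ((a:1,b:2)c:3,(d:4,e:5)f:6)root; written out by hand; assumed to match
# TreeNode.to_array for this tree (only the keys used in this notebook)
tree_array = {
    'id': np.arange(7),
    'name': np.array(['a', 'b', 'd', 'e', 'c', 'f', 'root'], dtype=object),
    'length': np.array([1, 2, 4, 5, 3, 6, np.nan]),
    # one row per internal node: [node ID, first child ID, last child ID]
    'child_index': np.array([[4, 0, 1],
                             [5, 2, 3],
                             [6, 4, 5]]),
}


def children_of(indexed, node_id):
    """Names of a node's children (hypothetical helper)."""
    for node, left, right in indexed['child_index']:
        if node == node_id:
            # the children occupy one contiguous slice of every array
            return list(indexed['name'][left:right + 1])
    return []  # tips have no row in child_index


print(children_of(tree_array, 6))  # ['c', 'f']
print(children_of(tree_array, 5))  # ['d', 'e']
```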

The motivation for this notebook is to minimize the time and space required to extract a subtree directly from the array representation. These arrays are the primary data structure used by Fast UniFrac. While developing a parallelized version of Fast UniFrac, it became apparent that this operation would be performed many times, and that the space and time needed for an implicit copy of the tree were a concern. The performance of Fast UniFrac is tightly tied to the number of nodes in the tree being operated on, so minimizing the number of nodes represented is critical.
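
As a rough illustration of how much this can save, assume single-descendant nodes are also collapsed (as collapse below does), so every internal node has at least two children. Let n be the number of nodes, k the number of tips and I the number of internal nodes, and let c(v) be the number of children of node v. Every non-root node is the child of exactly one internal node, which gives:

$$
n - 1 = \sum_{v\ \text{internal}} c(v) \ge 2I, \qquad n = k + I \quad\Longrightarrow\quad I \le k - 1, \qquad n \le 2k - 1
$$

A subtree spanning the 10,000 tips retained in the benchmark below therefore has at most 19,999 nodes, regardless of the size of the full tree. Shearing alone gives no such bound, since each kept tip brings along its entire path to the root.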

scikit-bio's TreeNode does provide a shear method; however, it operates only on TreeNode objects and is very expensive in time and space on large trees (i.e., millions of tips) because it must make a copy of the tree. ete2 provides a comparable method, prune, but it operates in place on the tree and is therefore unsuitable when multiple subtrees must be derived. Because of the rich nature of their objects (which is generally a very good thing), both ete2 and scikit-bio carry large memory requirements relative to a tree array: on the order of tens of GB versus tens of MB.


In [313]:
import numpy as np
from skbio import TreeNode

def shear(indexed, to_keep):
    """Shear off nodes from a tree array
    
    Parameters
    ----------
    indexed : dict
        The result of TreeNode.to_array
    to_keep : set
        The tip IDs of the tree to keep
        
    Returns
    -------
    dict
        A TreeNode.to_array like dict with the exception that "id_index" is not
        provided, and any extraneous attributes formerly included are not 
        passed on.
    
    Notes
    -----
    Unlike TreeNode.shear, this method does not prune (i.e., collapse
    single-descendant nodes); use collapse for that step.
    
    This method assumes that to_keep is a subset of names in the tree.
    
    The order of the nodes remains unchanged.
    """
    to_keep = set(to_keep)
    
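    # nodes to keep mask (np.bool was removed from NumPy; use the builtin)
    mask = np.zeros(len(indexed['id']), dtype=bool)

    # set any tips marked "to_keep"; dtype=int keeps the index valid even
    # when no tip names match
    tips_to_keep = [i for i, n in enumerate(indexed['name']) if n in to_keep]
    mask[np.asarray(tips_to_keep, dtype=int)] = True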

    # perform a post-order traversal and identify any nodes that should be 
    # retained
    new_child_index = []
    for node_idx, child_left, child_right in indexed['child_index']:
        being_kept = mask[child_left:child_right + 1]

        # keep the node if any of its children are kept. NOTE: the second
        # clause explicitly keeps the root (the last node). This may not be
        # necessary and may be a remnant of earlier experimentation.
        if being_kept.any() or node_idx == indexed['id'][-1]:
            mask[node_idx] = True

    # we now know which nodes to keep, so we can create new IDs for assignment
    new_ids = np.arange(mask.sum(), dtype=int)

    # map old node IDs (array positions) to the new IDs; entries for dropped
    # nodes are never read
    id_map = np.zeros(len(indexed['id']), dtype=int)
    id_map[mask] = new_ids

    # perform another post-order traversal to construct the new child index
    # array, which gives the index positions of the descendants of each kept
    # internal node
    for node_idx, child_left, child_right in indexed['child_index']:
        being_kept = mask[child_left:child_right + 1]

        # same rule as above: keep nodes with a kept child, plus the root
        if being_kept.any() or node_idx == indexed['id'][-1]:
            new_id = id_map[node_idx]
            child_indices = indexed['id'][child_left:child_right + 1][being_kept]
            left_child = id_map[child_indices[0]]
            right_child = id_map[child_indices[-1]]
            new_child_index.append([new_id, left_child, right_child])

    new_child_index = np.asarray(new_child_index)

    return {'child_index': new_child_index,
            'length': indexed['length'][mask],
            'name': indexed['name'][mask],
            'id': new_ids}

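def collapse(indexed):
    """Collapse single-descendant nodes in a tree array

    Each node with exactly one child is removed and its branch length is
    added to that child. If the root has a single child, that child becomes
    the new root and its length is set to NaN.

    Parameters
    ----------
    indexed : dict
        A TreeNode.to_array like dict, e.g. the result of shear

    Returns
    -------
    dict
        A dict with "child_index", "length", "name" and "id". If no node
        needs collapsing, the input dict is returned unchanged. Nodes are
        grouped by parent, so their order may differ from the input.
    """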
    lengths = indexed['length'].copy()
    names = indexed['name']
    id_ = indexed['id']
    child_index = indexed['child_index']
        
    mask = np.ones(len(lengths), dtype=bool)  # nodes that survive the collapse
    parent = np.zeros(len(lengths), dtype=int)

    # derive each node's parent from the contiguous child ranges
    for node_idx, left, right in child_index:
        parent[left:right+1] = node_idx
    parent[-1] = -1  # root
    
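    # walk from the root toward the tips so that branch lengths accumulate
    # down chains of single-descendant nodes
    for node_idx, left, right in child_index[::-1]:
        if left == right:
            if parent[node_idx] == -1:
                # the root has a single descendant, which becomes the new root
                lengths[left] = np.nan
                parent[left] = -1
            else:
                # fold this node into its only child and bypass it
                lengths[left] += lengths[node_idx]
                parent[left] = parent[node_idx]

            mask[node_idx] = False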

    # nothing was collapsed
    if mask.all():
        return indexed

    # for sorting, set the root's parent to one past the largest parent index
    root_index = parent.max() + 1
    parent = np.where(parent == -1, root_index, parent)

    # group nodes by parent; since a parent always has a larger ID than its
    # children, children still precede their parents. A stable sort keeps the
    # original sibling order on large arrays as well.
    parent_sorted_indices = np.argsort(parent, kind='stable')
    left = parent_sorted_indices[0]
    left_parent = parent[left]
    last_right = left

    # one child_index row per change of parent (the root's group comes last)
    new_index_offset = 0
    new_index = np.zeros(((np.diff(parent[parent_sorted_indices]) != 0).sum(), 3), dtype=int)

    # surviving nodes, in grouped order, receive consecutive new IDs
    kept_sorted = parent_sorted_indices[mask[parent_sorted_indices]]
    new_ids = np.arange(mask.sum(), dtype=int)
    new_id_lookup = np.zeros(len(id_), dtype=int)
    new_id_lookup[kept_sorted] = new_ids

    for right in kept_sorted:
        right_parent = parent[right]

        # a new parent group starts: emit the row for the group just finished
        if left_parent != right_parent:
            new_left_id, new_right_id = sorted([new_id_lookup[left], new_id_lookup[last_right]])
            new_index[new_index_offset] = np.array([new_id_lookup[left_parent],
                                                    new_left_id,
                                                    new_right_id])
            left = right
            left_parent = parent[left]
            new_index_offset += 1

        last_right = right

    result = {'child_index': new_index,
              'length': lengths[kept_sorted],
              'name': names[kept_sorted],
              'id': new_ids}

    return result
    
class FromArrayTreeNode(TreeNode):
    """TreeNode subclass that can be constructed from a tree array"""
    @classmethod
    def from_array(cls, tree_as_array):
        # one node per array entry; NaN lengths (e.g. the root) become None
        nodes = [cls(name=n, length=(l if not np.isnan(l) else None))
                 for n, l in zip(tree_as_array['name'],
                                 tree_as_array['length'])]

        # attach each internal node's contiguous run of children
        for parent_idx, left_child, right_child in tree_as_array['child_index']:
            parent = nodes[parent_idx]
            parent.extend(nodes[left_child:right_child+1])

        # the root is the last node
        return nodes[-1]
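
Taken together, shear and collapse keep every tip in to_keep, every ancestor of those tips and the root, and then fold each remaining single-child node into its child while summing branch lengths. The sketch below restates that rule with a parent array instead of child_index ranges, which may be easier to check by eye. It is a simplified illustration rather than a replacement: it does not renumber nodes or rebuild child_index, and the names parent_array, keep_mask, merge_unary and shear_then_collapse are hypothetical. The tree is the small example from the tests, written out by hand on the assumption that it matches TreeNode.to_array.

```python
import numpy as np

# ((a:1,b:2)c:3,(d:4,e:5)f:6)root; written out by hand (assumed to_array layout)
tree_array = {
    'id': np.arange(7),
    'name': np.array(['a', 'b', 'd', 'e', 'c', 'f', 'root'], dtype=object),
    'length': np.array([1, 2, 4, 5, 3, 6, np.nan]),
    'child_index': np.array([[4, 0, 1], [5, 2, 3], [6, 4, 5]]),
}


def parent_array(child_index, n_nodes):
    """Parent ID of every node (-1 for the root), built without a Python loop."""
    nodes, left, right = child_index[:, 0], child_index[:, 1], child_index[:, 2]
    counts = right - left + 1
    # expand each [left, right] range into explicit child IDs
    offsets = np.cumsum(counts) - counts
    children = np.repeat(left - offsets, counts) + np.arange(counts.sum())
    parent = np.full(n_nodes, -1, dtype=int)
    parent[children] = np.repeat(nodes, counts)
    return parent


def keep_mask(parent, tip_ids):
    """Mark the kept tips, all of their ancestors and the root."""
    mask = np.zeros(len(parent), dtype=bool)
    for node in tip_ids:
        # climb toward the root, stopping at an ancestor that is already marked
        while node != -1 and not mask[node]:
            mask[node] = True
            node = parent[node]
    mask[-1] = True  # the root is the last node
    return mask


def merge_unary(parent, lengths, mask):
    """Fold kept single-child nodes into their child, summing branch lengths."""
    parent, lengths, mask = parent.copy(), np.array(lengths, dtype=float), mask.copy()
    # kept children per node; folding a node leaves these counts unchanged
    n_children = np.bincount(parent[mask & (parent >= 0)], minlength=len(parent))
    # a parent always has a larger ID than its children, so descending IDs
    # visit parents first and lengths accumulate down single-child chains
    for node in range(len(parent) - 1, -1, -1):
        if not mask[node] or n_children[node] != 1:
            continue
        child = np.flatnonzero(mask & (parent == node))[0]
        if parent[node] == -1:
            lengths[child] = np.nan  # the only child becomes the new root
        else:
            lengths[child] += lengths[node]
        parent[child] = parent[node]
        mask[node] = False
    return parent, lengths, mask


def shear_then_collapse(indexed, to_keep):
    """Return {name: (parent name or None, branch length)} for surviving nodes."""
    names = indexed['name']
    parent = parent_array(indexed['child_index'], len(names))
    tips = [i for i, n in enumerate(names) if n in to_keep]
    mask = keep_mask(parent, tips)
    parent, lengths, mask = merge_unary(parent, indexed['length'], mask)
    return {names[i]: (names[parent[i]] if parent[i] != -1 else None,
                       float(lengths[i]))
            for i in np.flatnonzero(mask)}


print(shear_then_collapse(tree_array, {'a', 'e'}))
print(shear_then_collapse(tree_array, {'d', 'e'}))
```

Keeping a and e leaves both directly under the root with lengths 4 and 11, and keeping d and e makes f the new root, which is what test_collapse and test_collapse_identity_dropped_clade expect. parent_array also suggests that the Python loop used to build the parent array in collapse could be vectorized with np.repeat, which may matter on trees with millions of nodes.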

In [314]:
from unittest import TestCase
import numpy.testing as npt
import numpy as np


class ShearTests(TestCase):
    def test_shear_identity(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        to_keep = {'a', 'b', 'd', 'e'}
        obs = shear(tree, to_keep)
        npt.assert_equal(obs['length'], tree['length'])
        npt.assert_equal(obs['id'], tree['id'])
        npt.assert_equal(obs['name'], tree['name'])
        npt.assert_equal(obs['child_index'], tree['child_index'])

    def test_shear_drop_clade(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        to_keep = {'d', 'e'}
        exp = {'length': np.array([4, 5, 6, np.nan]),
               'name': np.array(['d', 'e', 'f', 'root']),
               'id': np.array([0, 1, 2, 3]),
               'child_index': np.array([[2, 0, 1],
                                        [3, 2, 2]])}

        obs = shear(tree, to_keep)
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])

    def test_shear_complex_identity(self):
        tree = TreeNode.read(['((a:1,b:2,c:3)d:4,'
                              '(((e:4)f:5)g:6)h:7,'
                              '((((i:8,j:9)k:10,(l:11)m:12)n:13)o:14,p:15)q:16'
                              ')root;'])
        to_keep = {n.name for n in tree.tips()}
        exp = {'length': np.array([1, 2, 3, 4, 5, 6, 8, 9, 11, 10, 12, 13, 14,
                                   15, 4, 7, 16,  np.nan]),
               'id': np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                               14, 15, 16, 17]),
               'child_index': np.array([[4, 3, 3],
                                        [5, 4, 4],
                                        [9, 6, 7],
                                        [10, 8, 8],
                                        [11, 9, 10],
                                        [12, 11, 11],
                                        [14, 0, 2],
                                        [15, 5, 5],
                                        [16, 12, 13],
                                        [17, 14, 16]]),
               'name': np.array(['a', 'b', 'c', 'e', 'f', 'g', 'i', 'j', 'l',
                                 'k', 'm', 'n', 'o', 'p', 'd', 'h', 'q',
                                 'root'], dtype=object)}

        tree_array = tree.to_array()

        obs = shear(tree_array, to_keep)
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])

    def test_shear_complex(self):
        tree = TreeNode.read(['((a:1,b:2,c:3)d:4,'
                              '(((e:4)f:5)g:6)h:7,'
                              '((((i:8,j:9)k:10,(l:11)m:12)n:13)o:14,p:15)q:16'
                              ')root;'])
        to_keep = {'b', 'c', 'i', 'l', 'p'}
        exp = {'length': np.array([2, 3, 8, 11, 10, 12, 13, 14, 15, 4, 16,
                                   np.nan]),
               'id': np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
               'child_index': np.array([[4, 2, 2],  # k
                                        [5, 3, 3],  # m
                                        [6, 4, 5],  # n
                                        [7, 6, 6],  # o
                                        [9, 0, 1],  # d
                                        [10, 7, 8],  # q
                                        [11, 9, 10]]),  # root
               'name': np.array(['b', 'c', 'i', 'l', 'k', 'm', 'n', 'o', 'p',
                                 'd', 'q', 'root'], dtype=object)}

        tree_array = tree.to_array()

        obs = shear(tree_array, to_keep)
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])
        
        
class CollapseTests(TestCase):
    def test_collapse_identity(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        obs = collapse(tree)
        npt.assert_equal(obs['length'], tree['length'])
        npt.assert_equal(obs['id'], tree['id'])
        npt.assert_equal(obs['name'], tree['name'])
        npt.assert_equal(obs['child_index'], tree['child_index'])

    def test_collapse_identity_dropped_clade(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        to_keep = {'d', 'e'}
        exp = {'length': np.array([4, 5, np.nan]),
               'name': np.array(['d', 'e', 'f']),
               'id': np.array([0, 1, 2]),
               'child_index': np.array([[2, 0, 1]])}
                                    

        obs = collapse(shear(tree, to_keep))
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])
        
    def test_collapse(self):
        tree = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;']).to_array()
        to_keep = {'a', 'e'}
        exp = {'length': np.array([4, 11, np.nan]),
               'name': np.array(['a', 'e', 'root']),
               'id': np.array([0, 1, 2]),
               'child_index': np.array([[2, 0, 1]])}

        obs = collapse(shear(tree, to_keep))
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])
        
    def test_collapse_complex(self):
        tree = TreeNode.read(['((a:1,b:2,c:3)d:4,'
                              '(((e:4)f:5)g:6)h:7,'
                              '((((i:8,j:9)k:10,(l:11)m:12)n:13)o:14,p:15)q:16'
                              ')root;'])
        to_keep = {'b', 'c', 'i', 'l', 'p'}
        exp = {'length': np.array([18, 23, 2, 3, 27, 15, 4, 16, np.nan]),
               'id': np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
               'child_index': np.array([[4, 0, 1],  # n
                                        [6, 2, 3],  # d
                                        [7, 4, 5],  # q
                                        [8, 6, 7]]),  # root
               'name': np.array(['i', 'l', 'b', 'c', 'n', 'p', 'd', 'q', 'root'], dtype=object)}

        tree_array = tree.to_array()

        obs = collapse(shear(tree_array, to_keep))
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])
        
    def test_collapse_to_root(self):
        tree = TreeNode.read(['(((((a:1)b:2)c:3)d:4)e:5)root;'])
        exp = {'length': np.array([np.nan]),
               'name': np.array(['a']),
               'id': np.array([0]),
               'child_index': np.zeros((0, 3))}
        
        obs = collapse(tree.to_array())
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])
        
    def test_multiple_collapse(self):
        tree = TreeNode.read(['(((((a:1)b:2)c:3)d:4)e:5,f:6)root;'])
        exp = {'length': np.array([15, 6, np.nan]),
               'name': np.array(['a', 'f', 'root']),
               'id': np.array([0, 1, 2]),
               'child_index': np.array([[2, 0, 1]])}
        
        obs = collapse(tree.to_array())
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])

    def test_multiple_multiple_collapse(self):
        tree = TreeNode.read(['((((a:1)b:2)c:3)d:4,(((e:5)f:6)g:7)h:8)root;'])
        exp = {'length': np.array([10, 26, np.nan]),
               'name': np.array(['a', 'e', 'root']),
               'id': np.array([0, 1, 2]),
               'child_index': np.array([[2, 0, 1]])}

        obs = collapse(tree.to_array())
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        npt.assert_equal(obs['child_index'], exp['child_index'])
        
    def test_reordering(self):
        tree = TreeNode.read(['((a:1)c:3,(d:4,e:5)f:6)root;'])
        exp = {'length': np.array([4, 5, 4, 6, np.nan]),
               'name': np.array(['d', 'e', 'a', 'f', 'root']),
               'id': np.array([0, 1, 2, 3, 4]),
               'child_index': np.array([[3, 0, 1],
                                        [4, 2, 3]])}

        obs = collapse(tree.to_array())
        npt.assert_equal(obs['child_index'], exp['child_index'])
        npt.assert_equal(obs['length'], exp['length'])
        npt.assert_equal(obs['id'], exp['id'])
        npt.assert_equal(obs['name'], exp['name'])
        
class TreeNodeFromArray(TestCase):
    def test_from_array_simple(self):
        exp = TreeNode.read(['((a:1,b:2)c:3,(d:4,e:5)f:6)root;'])
        obs = FromArrayTreeNode.from_array(exp.to_array())
        
        obs.assign_ids()
        exp.assign_ids()
        
        for o, e in zip(obs.traverse(), exp.traverse()):
            self.assertEqual(o.name, e.name)
            self.assertEqual(o.length, e.length)
            self.assertEqual(o.id, e.id)
            
    def test_from_array_complex(self):
        exp = TreeNode.read(['((a:1,b:2,c:3)d:4,'
                             '(((e:4)f:5)g:6)h:7,'
                             '((((i:8,j:9)k:10,(l:11)m:12)n:13)o:14,p:15)q:16'
                             ')root;'])
        obs = FromArrayTreeNode.from_array(exp.to_array())
        
        obs.assign_ids()
        exp.assign_ids()

        for o, e in zip(obs.traverse(), exp.traverse()):
            self.assertEqual(o.name, e.name)
            self.assertEqual(o.length, e.length)
            self.assertEqual(o.id, e.id)

In [315]:
# adapted from http://amodernstory.com/2015/06/28/running-unittests-in-the-ipython-notebook/
from unittest import TestLoader, TextTestRunner

test_loader = TestLoader()
runner = TextTestRunner()

tests = (ShearTests, CollapseTests, TreeNodeFromArray)

for testcase in tests:
    suite = test_loader.loadTestsFromTestCase(testcase)
    runner.run(suite)


....
----------------------------------------------------------------------
Ran 4 tests in 0.009s

OK
........
----------------------------------------------------------------------
Ran 8 tests in 0.018s

OK
..
----------------------------------------------------------------------
Ran 2 tests in 0.004s

OK

In [316]:
from random import shuffle
from time import time

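# Greengenes 13_8 97% OTU tree; adjust the path for your environment
gg = TreeNode.read('/Users/daniel/miniconda3/envs/qiime191/lib/python2.7/site-packages/qiime_default_reference/gg_13_8_otus/trees/97_otus.tree')

def bench(tree, n_tips, n_iterations):
    """Compare TreeNode.shear with shear + collapse on random tip subsets"""
    tipnames = [n.name for n in tree.tips()]
    tree_array = tree.to_array()

    for _ in range(n_iterations):
        shuffle(tipnames)
        tips_to_keep = tipnames[:n_tips]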
        
        skbio_shear_start = time()
        skbio_shear = tree.shear(tips_to_keep)
        skbio_shear_time = time() - skbio_shear_start
        
        array_shear_start = time()
        array_shear = shear(tree_array, tips_to_keep)
        array_shear_time = time() - array_shear_start
        
        array_collapse_start = time()
        array_collapse = collapse(array_shear)
        array_collapse_time = time() - array_collapse_start
        
        print("skbio shear: %0.2f" % skbio_shear_time)
        print("array shear: %0.2f" % array_shear_time)
        print("array collapse: %0.2f" % array_collapse_time)
        print("array total: %0.2f" % (array_collapse_time + array_shear_time))
        print()
        
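        # sanity check: converted back to a TreeNode, the array result should
        # hold the same node names and branch lengths as TreeNode.shear
        array_recovered = FromArrayTreeNode.from_array(array_collapse)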
        array_recovered_names = [n.name or '' for n in array_recovered.traverse()]
        array_recovered_lengths = [n.length or 0.0 for n in array_recovered.traverse()]

        skbio_names = [n.name or '' for n in skbio_shear.traverse()]
        skbio_lengths = [n.length or 0.0 for n in skbio_shear.traverse()]

        assert sorted(array_recovered_names) == sorted(skbio_names)
        assert sorted(array_recovered_lengths) == sorted(skbio_lengths)

bench(gg, 10000, 2)


skbio shear: 15.52
array shear: 1.01
array collapse: 0.22
array total: 1.23

skbio shear: 17.59
array shear: 1.12
array collapse: 0.23
array total: 1.34
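With 10,000 tips retained from this tree, the printed timings put the array-based route at roughly an order of magnitude faster than TreeNode.shear, with collapse accounting for under a fifth of the array time. The ratios follow directly from the rounded numbers above; two iterations on a single tree is a small sample, so they are indicative only.

```python
# rounded timings (seconds) copied from the two benchmark iterations above
skbio_shear_times = [15.52, 17.59]
array_total_times = [1.23, 1.34]

speedups = [s / a for s, a in zip(skbio_shear_times, array_total_times)]
for speedup in speedups:
    print("speedup: %0.1fx" % speedup)  # 12.6x, then 13.1x
```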



In [304]:
tipnames = [n.name for n in gg.tips()]
tree_array = gg.to_array()
tips_to_keep = tipnames[:1000]
# profile a single shear call; with -r, %prun returns a pstats.Stats object
x = %prun -r array_shear = shear(tree_array, set(tips_to_keep))
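
Outside IPython, roughly the same profile can be gathered with the standard-library cProfile and pstats modules. The helper profile_call below is hypothetical (not part of either module), and workload merely stands in for the code under study; with this notebook's objects the call would be profile_call(shear, tree_array, set(tips_to_keep)).

```python
import cProfile
import io
import pstats


def profile_call(func, *args, top=10):
    """Profile one call of func; return (result, pstats.Stats, text report)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # sort by cumulative time and write the `top` most expensive entries
    stats.sort_stats('cumulative').print_stats(top)
    return result, stats, stream.getvalue()


def workload(n):
    """Stand-in for the code being profiled (e.g. shear)."""
    return sum(i * i for i in range(n))


result, stats, report = profile_call(workload, 10000)
print(report)
```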


 

In [317]:
%timeit gg.to_array()


1 loops, best of 3: 920 ms per loop

In [318]:
%timeit FromArrayTreeNode.from_array(tree_array)


1 loops, best of 3: 1.05 s per loop
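
Each %timeit line above reports the best of three single runs ("1 loops, best of 3"), so converting this tree to arrays and back costs roughly 2 s in total (0.92 s + 1.05 s). In a plain script, a comparable measurement can be taken with the standard-library timeit module. best_of is a hypothetical helper; with this notebook's objects it would be called as best_of(gg.to_array) or best_of(FromArrayTreeNode.from_array, tree_array).

```python
import timeit


def best_of(func, *args, repeat=3):
    """Best time in seconds over `repeat` single calls of func(*args)."""
    # number=1 runs the call once per measurement; the minimum is the least
    # noisy estimate, mirroring the "1 loops, best of 3" reports above
    times = timeit.repeat(lambda: func(*args), number=1, repeat=repeat)
    return min(times)


print("%.3f s" % best_of(sorted, list(range(100000, 0, -1))))
```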
